UP | HOME

A Little Idea about Middleware

在编写服务的 server 或者 client 时,我们都会用到中间件( middleware )。 中间件极高的复用性以及可组合性极大地拓展了它的用武之地。 本文旨在介绍一种采用 CPS 风格的 middleware 写法。

一般来说, middleware 出于和 handler 可组合性的考量都会将类型设计得尽可能贴近 handler 。

一个极端是 github.com/gin-gonic/gin , middleware 直接采用 handler 的类型。 比如 "github.com/gin-gonic/gin".(*Engine).Use 的定义:

func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes

更一般的, middleware 会将类型设计为类似 func(nextHandler Handler) Handler 的形式。 举几个例子:

// package: "github.com/go-chi/chi/middleware"
func Logger(next http.Handler) http.Handler

// package: "github.com/gorilla/handlers"
func CompressHandler(h http.Handler) http.Handler

// package: "google.golang.org/grpc"
type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)

前两个都很标准。 最后 gRPC 的做法虽然不是那么标准,但是考虑到 middleware (被称为 interceptor )返回了 handler 的返回值,并且也额外接受了 handler 的参数,进行一个类似科里化的变化就能变成那个形式了。

见识到了 func(nextHandler Handler) Handler 这种形式的普遍性后,下面说说这种形式的局限性。 由于对 nextHandler 的调用完全发生在 middleware 当中,所以之后 middleware 发生的 abort 难于让当前 middleware 感受到。 像 gin 一样自行封装 "github.com/gin-gonic/gin".Context 类型并直接提供 "github.com/gin-gonic/gin".(*Context).IsAborted 当然是一种可行的做法——但这要求要么 response 原生就提供 aborted 判断,要么得花额外的工作量来封装新类型。 其实这种局限只是控制流上的局限性,所以可以通过引入 CPS 风格来解决。

首先我们给出 middleware 的类型(注意到这里 Middleware 类型是递归的,想要单独类型递归可以看笔者旧文 Type Recursion in Golang 了解):

type Middleware func(req Request, resp *Response, nextMiddleware Middleware)

func (mw Middleware) Pass(req Request, resp *Response, nextMiddleware Middleware) {
        if mw != nil {
                mw(req, resp, nextMiddleware)
        }
}

这里之所以定义了一个 Middleware.Pass 方法来代替直接调用是因为这样就可以直接用 nil 来表示空 middleware 也就是零元。 这个 Middleware 类型和我们刚才介绍的常见的形式 func(nextHandler Handler) Handler 最大的区别其实就在于将 nextHandler Handler 参数替换成了 nextMiddleware Middleware 。 这样整个 middleware 系统就“自成一体”,(暂时地)抛开了 handler ,也就抛开了 handler 所受的一些限制。

在这个 middleware 系统中, middleware 写法的模板是这样的:

func DemoMiddleware(req Request, resp *Response, next Middleware) {

        DoSomePreJobs()

        if shouldNotProcessReq {
                *resp = someResp
                return
        }

        next.Pass(req, resp, CastMagic)

        DoSomePostJobs()
}

DoSomePreJobsDoSomePostJobs 都很好理解, CastMagic 有什么蹊跷之处么? 当然有, CastMagic 就是用来控制代码执行的控制流的“注入点”。 这么说有点尬吹了,不妨先看两个例子:

func AbortWithFn(f func()) Middleware {
        return func(req Request, resp *Response, next Middleware) {
                f()
        }
}

func LiftFn(f func()) Middleware {
        return func(req Request, resp *Response, next Middleware) {
                f()
                next.Pass(req, resp, nil)
        }
}

在实际使用中,可以用 AbortWithFnLiftFn 构造的 Middleware 来填充 CastMagic 的值,从而达到 abort 的效果或是在确保未进行 abort 的情况下观察 response 。

不妨先看一个简单的 middleware Logger

func Logger(req Request, resp *Response, next Middleware) {

        fmt.Println("Start")
        defer fmt.Println("End")

        next.Pass(req, resp, LiftFn(func() {
                fmt.Printf("Got resp: %q\n", *resp)
        }))
}

实际上运行的时候,这些 middleware 一个接着一个靠 next 给串起来,而 next.Pass(req, resp, LiftFn(...)) 相当于是动态地给当前的 middleware 串又 append 一个新的。 而所有后续的要执行的 middleware 都会包含在靠前的 middleware 对应的 nextMiddleware 里,所以一旦靠前的 middleware 不执行 nextMiddleware ,所有后面的 middleware 就都不会执行了,这也就是 AbortWithFn 的做法。 这里可能有点绕,不过后面的示例代码应该有比较明晰的展示效果。

在展示示例代码之前,还有两件事要做:一是将“(暂时)抛开的” handler 接回来,毕竟前文所述 middleware 已经自成一体了,对于 handler 需要做点工作才能融入这个体系;二是 middleware 的组合如何进行,实际上也是需要点技巧的。

方便起见,我们就挑 func(Request, *Response) 作为 Handler 的定义,并给出相应的用于将 Handler 提升为 Middleware 的函数 LiftHandler

type Handler func(Request, *Response)

func LiftHandler(h Handler) Middleware {
        return func(req Request, resp *Response, next Middleware) {
                h(req, resp)
                next.Pass(req, resp, nil)
        }
}

middleware 的组合当然是从右向左依次进行,不过最基本的组合两个 middleware 的操作会因为采用 CPS 风格而复杂一些。 先假装已经有了组合两个 middleware 的函数 Compose ,看如何拼装多个:

func Chain(mws ...Middleware) func(Handler) Handler {

        r := Middleware(nil)

        for _, mw := range mws {
                r = Compose(r, mw)
        }

        return func(h Handler) Handler {
                return func(req Request, resp *Response) {
                        r.Pass(req, resp, LiftHandler(h))
                }
        }
}

接下来尝试给出 Compose 的实现。 首先回想一下 middleware 和 nextMiddleware 的关系: middleware 将需要交由后续 middleware 托管的代码逻辑通过作为 nextMiddlewarenextMiddleware 传递给 nextMiddleware.Pass() 来往后传递。 也就是说对于本层 middleware 而言, nextMiddleware 就是下一个 middleware ,而自己传给 nextMiddlewarenextMiddleware 却是 append 到目前这个 middleware 串的最末。 所以对于两个 middleware mw1mw2Compose(mw1, mw2) 这个中间件需要对它的 nextMiddleware (不妨称为 mw3 )传入原本 mw1mw2 最后要做的事(不妨称为 next1next2 ),并且顺序是是 next2next1 。 还是有点绕,不过结合代码多看看应该就理解了:

func Compose(mw1 Middleware, mw2 Middleware) Middleware {
        switch {
        case mw1 == nil:
                return mw2
        case mw2 == nil:
                return mw1
        default:
                return func(req Request, resp *Response, mw3 Middleware) {
                        mw1.Pass(req, resp, func(req Request, resp *Response, next1 Middleware) {
                                mw2.Pass(req, resp, func(req Request, resp *Response, next2 Middleware) {
                                        mw3.Pass(req, resp, Compose(next2, next1))
                                })
                        })
                }
        }
}

行文至此,这个 CPS 风格的 middleware 系统基本上就介绍完毕了。 最后上示例来看看效果:

type (
        Request  = int
        Response = string
)

func main() {

        handler := func(req Request, resp *Response) {
                fmt.Println("Handler")
                *resp = strconv.FormatInt(int64(req), 10)
        }

        mws := make([]Middleware, 3)
        for i := range mws {
                i := i
                mws[i] = func(req Request, resp *Response, next Middleware) {
                        fmt.Printf("MW%d start...\n", i)
                        next.Pass(req, resp, LiftFn(func() {
                                *resp = fmt.Sprintf("returned from MW%d: [%s]", i, *resp)
                                fmt.Printf("MW%d end\n", i)
                        }))
                }
        }

        abortAtRespMws := make([]Middleware, len(mws))
        for i := range abortAtRespMws {
                i := i
                abortAtRespMws[i] = func(req Request, resp *Response, next Middleware) {
                        fmt.Printf("MW%d start...\n", i)
                        next.Pass(req, resp, AbortWithFn(func() {
                                fmt.Printf("MW%d abort\n", i)
                        }))
                }
        }

        // print out:
        //      Start
        //      MW0 start...
        //      MW1 start...
        //      MW2 start...
        //      Handler
        //      MW2 end
        //      MW1 end
        //      MW0 end
        //      Got resp: "returned from MW0: [returned from MW1: [returned from MW2: [1]]]"
        //      End
        Chain(Logger, mws[0], mws[1], mws[2])(handler)(1, new(string))

        // print out:
        //      Start
        //      MW0 start...
        //      MW1 start...
        //      MW2 start...
        //      Handler
        //      MW2 end
        //      MW1 end
        //      MW0 abort
        //      End
        Chain(Logger, abortAtRespMws[0], mws[1], mws[2])(handler)(2, new(string))

        // print out:
        //      Start
        //      MW0 start...
        //      MW1 start...
        //      MW2 start...
        //      Handler
        //      MW2 end
        //      MW1 abort
        //      End
        Chain(Logger, mws[0], abortAtRespMws[1], mws[2])(handler)(3, new(string))

        // print out:
        //      Start
        //      MW0 start...
        //      MW1 start...
        //      MW2 start...
        //      Handler
        //      MW2 abort
        //      End
        Chain(Logger, mws[0], mws[1], abortAtRespMws[2])(handler)(4, new(string))
}